# datasets/synthetic.py
from __future__ import annotations

from typing import Optional, Dict, Any

import numpy as np
from sklearn.datasets import make_swiss_roll

from .base import DatasetSpec
from .registry import register
from .transforms import default_preprocess

@register("swiss_roll")
def load_swiss_roll(
    cache_dir: Optional[str] = None,
    *,
    n_samples: int = 8000,
    noise: float = 0.0,
    preprocess: bool = True,
    pca_n: Optional[int] = 0,   # keep as 3D ambient by default
    random_state: int = 0,
) -> DatasetSpec:
    """
    Generate a synthetic 3D Swiss-roll dataset with a known 2D manifold.

    The label of each point is its unrolled roll parameter ``t`` divided by
    ``t.max()``, so the largest label is exactly 1. The values are NOT shifted,
    so the smallest label is ``t.min() / t.max()`` rather than 0.
    Handy for topology/continuity checks and for visualization tests.

    Args:
        cache_dir: Accepted but not used by this function (nothing is read
            from or written to disk here).
        n_samples: Number of points to draw.
        noise: Noise level forwarded to ``make_swiss_roll``.
        preprocess: When true, the points are passed through
            ``default_preprocess`` (normalize/log1p off, no HVG selection) and
            the info dict it returns is merged into ``meta``.
        pca_n: Forwarded to ``default_preprocess``; the default of 0 is meant
            to keep the 3D ambient coordinates (exact semantics live in
            ``default_preprocess``).
        random_state: Seed used for both the generator and the preprocessing.

    Returns:
        A ``DatasetSpec`` named ``"swiss_roll"`` whose ``meta`` holds
        ``"ambient_dim"`` (3) and ``"swiss_t"`` (a copy of the labels), plus
        any preprocessing info when ``preprocess`` is true.
    """
    points, roll_param = make_swiss_roll(
        n_samples=n_samples, noise=noise, random_state=random_state
    )
    points = points.astype(np.float32)

    # Divide by the maximum only: upper bound is 1, lower bound is min/max.
    largest_t = roll_param.max()
    labels = (roll_param / largest_t).astype(np.float32)

    # Store an independent copy so later edits to `labels` don't leak into meta.
    meta: Dict[str, Any] = {"ambient_dim": 3, "swiss_t": labels.copy()}

    features = points
    if preprocess:
        features, prep_info = default_preprocess(
            points,
            normalize=False,
            log1p=False,
            hvg_n=None,
            pca_n=pca_n,
            random_state=random_state,
        )
        meta.update(prep_info)

    return DatasetSpec(name="swiss_roll", X=features, labels=labels, meta=meta)
